{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EheA5_j_cEwc"
      },
      "source": [
        "##### Copyright 2021 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YCriMWd-pRTP"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OvRwFTkqcp1e"
      },
      "source": [
        "# Vectorization and XLA compilation\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/google/tf-quant-finance/blob/master/tf_quant_finance/examples/jupyter_notebooks/Vectorization_and_XLA_compilation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/google/tf-quant-finance/blob/master/tf_quant_finance/examples/jupyter_notebooks/Vectorization_and_XLA_compilation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "uG8UAZXjeUhf"
      },
      "outputs": [],
      "source": [
        "#@title Upgrade to TensorFlow 2.5+\n",
        "!pip install --upgrade tensorflow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "aijwbAA8u_IM"
      },
      "outputs": [],
      "source": [
        "#@title Install TF Quant Finance\n",
        "!pip install tf-quant-finance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K5lDRRp8Rkyy"
      },
      "source": [
        "\n",
        "In this notebook we use option pricing as an example problem to illustrate several approaches to optimization:\n",
        "\n",
        "* Baseline one-at-a-time computation.\n",
        "* Batched computation.\n",
        "* Vectorized computation.\n",
        "* XLA compilation.\n",
        "\n",
        "\n",
        "If reader is not familiar with the concepts, going through [the training](https://colab.research.google.com/github/google/tf-quant-finance/blob/master/tf_quant_finance/examples/jupyter_notebooks/Introduction_to_TensorFlow_Part_3_-_Advanced_Tensor_Manipulation.ipynb) first is recommended.\n",
        "\n",
        "The results in this colab were obtained with the public GPU runtime:\n",
        "\n",
        "`Runtime -\u003e Change runtime type -\u003e Hardware accelerator -\u003e GPU`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HjZ9D1DARXUM"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tf_quant_finance as tff\n",
        "import time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 379,
          "status": "ok",
          "timestamp": 1626954174053,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "QkNkP781Q-1J",
        "outputId": "caaccedf-e948-44b5-a3cd-e65e5118d07b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thu Jul 22 11:42:53 2021       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   44C    P0    26W /  70W |    224MiB / 15109MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "# Check that a GPU is available\n",
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OptbxBwLcNoT"
      },
      "source": [
        "## Sequential pricing\n",
        "\n",
        "Consider a simple example of Black-Scholes American option pricing using binomial tree algorithm. Below we generate data for a 1000 random options\n",
        "and price them all one by one."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MLPyFjwucUfw"
      },
      "outputs": [],
      "source": [
        "# Generate 1000 random call options\n",
        "dtype = tf.float64\n",
        "num_options = 1000\n",
        "spots = 1 + tf.random.stateless_uniform(\n",
        "    [num_options], seed=[4, 2], dtype=dtype)\n",
        "strikes = spots + (0.5 - tf.random.uniform([num_options], dtype=dtype))\n",
        "# All options expire in in 1 year\n",
        "expiry = tf.constant(1.0, dtype=dtype)\n",
        "# Constant discount rates and volatilities\n",
        "discount_rate = tf.constant(0.01, dtype=dtype)\n",
        "volatility = tf.constant(1.0, dtype=dtype)\n",
        "# Randomly assing put/call indicators\n",
        "is_call_options = tf.cast(tf.random.stateless_binomial(\n",
        "    [num_options], seed=[4, 2], counts=True, probs=0.5), tf.bool)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "trm2gG-zcUfx"
      },
      "source": [
        "Let us first price each option in one-by-one manner"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SG6P8r87cUfx"
      },
      "outputs": [],
      "source": [
        "# Wrap option pricing function with tf.function to perform\n",
        "# calculation optimizations\n",
        "option_price = tf.function(tff.black_scholes.option_price_binomial)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 11540,
          "status": "ok",
          "timestamp": 1626958403761,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "TZiCUbWKcUfx",
        "outputId": "c07765c4-a0f0-4e70-a7a3-f13f20f34f88"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "wall time (no batching): 11126.479625701904 ms\n"
          ]
        }
      ],
      "source": [
        "# Pricing with no batching\n",
        "start_time = time.time()\n",
        "prices = []\n",
        "with tf.device('gpu:0'):\n",
        "  for spot, strike, is_call_option in zip(spots, strikes, is_call_options):\n",
        "    prices.append(option_price(\n",
        "        spots=spot, strikes=strike, expiries=expiry,\n",
        "        discount_rates=discount_rate, \n",
        "        volatilities=volatility,\n",
        "        is_american=True,\n",
        "        is_call_options=is_call_option,\n",
        "        dtype=dtype))\n",
        "print(f\"wall time (no batching): {1000 * (time.time() - start_time)} ms\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 346,
          "status": "ok",
          "timestamp": 1626954211359,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "0f0y-WXhcUfy",
        "outputId": "7f4fc59e-dee8-4e42-c2f4-8a3d2ad2ec56"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[\u003ctf.Tensor: shape=(), dtype=float64, numpy=0.32871543038795475\u003e,\n",
              " \u003ctf.Tensor: shape=(), dtype=float64, numpy=0.2948546293901379\u003e,\n",
              " \u003ctf.Tensor: shape=(), dtype=float64, numpy=0.4912944051602193\u003e,\n",
              " \u003ctf.Tensor: shape=(), dtype=float64, numpy=0.7329574846641592\u003e,\n",
              " \u003ctf.Tensor: shape=(), dtype=float64, numpy=0.5549488585678107\u003e]"
            ]
          },
          "execution_count": 9,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Prices of the 1st five options\n",
        "prices[:5]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XQ1yjab9JWjO"
      },
      "source": [
        "## Batching\n",
        "\n",
        "We can significantly improve the performance by pricing all the options in a single batch. Batching means that the inputs to the pricing functions are supplied as a single `Tenor` with a batch dimension. In this case `spots`, `strikes` and `is_call_options` are all `Tensor` of shape `[1000]` and can be \n",
        "supplied to `option_price` function as is. \n",
        "\n",
        "Most of TF ops are vectorized, as are many of the functions in TF Quant Finance, and allow batch calculations. Typically function\n",
        "documentation includes information whether the function supports batching. If it does, batching of the calculation can significantly reduce run time.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V110iVmYDbNJ"
      },
      "outputs": [],
      "source": [
        "# We can compute all the prices in a single batch\n",
        "prices = option_price(\n",
        "      spots=spots, strikes=strikes, expiries=expiry,\n",
        "      discount_rates=discount_rate, \n",
        "      volatilities=volatility,\n",
        "      is_american=True,\n",
        "      is_call_options=is_call_options,\n",
        "      dtype=dtype)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 7025,
          "status": "ok",
          "timestamp": 1626954683759,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "gs45g6lbgull",
        "outputId": "a78fcf92-a023-43d7-b2c1-ccab547740e3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "100 loops, best of 5: 10.6 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%%timeit\n",
        "with tf.device('gpu:0'):\n",
        "  option_price(\n",
        "        spots=spots, strikes=strikes, expiries=expiry,\n",
        "        discount_rates=discount_rate, \n",
        "        volatilities=volatility,\n",
        "        is_american=True,\n",
        "        is_call_options=is_call_options,\n",
        "        dtype=dtype).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 343,
          "status": "ok",
          "timestamp": 1626954686904,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "B29uk48LMSBs",
        "outputId": "d183d1d1-4cc1-41c3-90ea-e9cb7cb1b3dc"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "\u003ctf.Tensor: shape=(5,), dtype=float64, numpy=array([0.32871543, 0.29485463, 0.49129441, 0.73295748, 0.55494886])\u003e"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Prices of the 1st five options\n",
        "prices[:5]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "81pddCJhcVQ3"
      },
      "source": [
        "## Vectorization\n",
        "\n",
        "Sometimes batching is not implemented for arguments of a function (that is, some of the arguments are missing the batch shape). In this case, [`tf.vectorized_map`](https://www.tensorflow.org/s/results?q=tf.vectorized_map) might be extremely useful as it can parallelize calculations along the 0-th dimension of the input `Tensor`s. Below we demonstrate how to use vectorized map for the option pricing example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BzWW-HFPcXPZ"
      },
      "outputs": [],
      "source": [
        "# Define a pricing function for the inputs to parallelize on\n",
        "def pricer_fn(packed_args):\n",
        "  # Arguments are packed as a tuple\n",
        "  spot, strike, is_call_option = packed_args\n",
        "  return option_price(\n",
        "        spots=spot, strikes=strike, expiries=expiry,\n",
        "        discount_rates=discount_rate, \n",
        "        volatilities=volatility,\n",
        "        is_american=True,\n",
        "        is_call_options=is_call_option,\n",
        "        dtype=dtype)\n",
        "  \n",
        "# Wrap the calculation with a tf.function to enable calculation optimization\n",
        "@tf.function\n",
        "def vectorized_pricer(spots, strikes, is_call_options):\n",
        "  # Call pricer_fn on for each element of spots, strikes and is_call_options\n",
        "  return tf.vectorized_map(pricer_fn, (spots, strikes, is_call_options))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PWsAKnOWgogY"
      },
      "outputs": [],
      "source": [
        "prices = vectorized_pricer(spots, strikes, is_call_options)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 2277,
          "status": "ok",
          "timestamp": 1626955288766,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "igaR5Ft3euAx",
        "outputId": "52b3a479-884c-4318-c498-be57a8421761"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "10 loops, best of 5: 30.3 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%%timeit\n",
        "with tf.device('gpu:0'):\n",
        "  vectorized_pricer(spots, strikes, is_call_options).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 6,
          "status": "ok",
          "timestamp": 1626955289819,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "f1KMmdlPE8Qc",
        "outputId": "aa293a9c-f7ae-4cfb-ae0d-3705a66dd885"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "\u003ctf.Tensor: shape=(5,), dtype=float64, numpy=array([0.32871543, 0.29485463, 0.49129441, 0.73295748, 0.55494886])\u003e"
            ]
          },
          "execution_count": 35,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Prices of the 1st five options\n",
        "prices[:5]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g0Y1PcO2NcwE"
      },
      "source": [
        "## XLA compilation\n",
        "\n",
        "Sometimes speed performance of a calculation is unsatisfactory. This happens, for example, when an underlying computation graph consists of a large number of simple calculations and TensorFlow overhead becomes a bottleneck (when running a graph, TensorFlow needs to orchestrate op execution, which can take substatial amount of time).\n",
        "\n",
        "XLA compiler takes a different approach by generating an LLVM IR targeting the device (CPU/GPU/TPU) from the computational graph.\n",
        "\n",
        "Refer to the [official XLA page](https://www.tensorflow.org/xla) for more details. XLA Architecture details can be found [here](https://www.tensorflow.org/xla/architecture).\n",
        "\n",
        "Below we demonstrate how to use XLA Just-in-time (JIT) compiler to price European options under Heston model using Attari approximation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MvCP5sTbUTNi"
      },
      "outputs": [],
      "source": [
        "# We will price 1000 options\n",
        "\n",
        "num_options = 1000\n",
        "\n",
        "# Generate data\n",
        "dtype = np.float64\n",
        "mean_reversion = tf.random.uniform(\n",
        "    shape=[num_options], minval=0.1, maxval=10.0, dtype=dtype) \n",
        "thetas = tf.random.uniform(shape=[num_options], minval=0.1, maxval=.5,\n",
        "                           dtype=dtype) \n",
        "variances = tf.random.uniform(shape=[num_options], minval=0.01, maxval=0.5,\n",
        "                           dtype=dtype) \n",
        "discount_factors = tf.random.uniform(\n",
        "    shape=[num_options], minval=0.8, maxval=0.99, dtype=dtype) \n",
        "expiries = tf.random.uniform(\n",
        "    shape=[num_options], minval=0.1, maxval=10.0, dtype=dtype) \n",
        "forwards = 10.0\n",
        "spots = forwards * discount_factors\n",
        "volvol = tf.random.uniform(\n",
        "    shape=[num_options], minval=0.05, maxval=0.9,\n",
        "    dtype=dtype) \n",
        "\n",
        "strikes = tf.random.uniform(\n",
        "    shape=[num_options], minval=9, maxval=11,dtype=dtype) \n",
        "\n",
        "rhos = tf.random.uniform(\n",
        "    shape=[num_options], minval=-0.8, maxval=0.8, dtype=dtype) \n",
        "\n",
        "# Randomly assing put/call indicators\n",
        "is_call_options = tf.cast(tf.random.stateless_binomial(\n",
        "    [num_options], seed=[4, 2], counts=True, probs=0.5), tf.bool)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QKZ6NzumTCBc"
      },
      "outputs": [],
      "source": [
        "# Wrap the pricing function with a tf.function\n",
        "european_option_price = tf.function(\n",
        "    tff.models.heston.approximations.european_option_price)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cpHyaj1hI_Nd"
      },
      "source": [
        "Check out the difference betwen GPU and CPU pricing speed (note there are 2vCPUs in a public colab). \n",
        "It does not seem that the GPU has done a good job."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 20868,
          "status": "ok",
          "timestamp": 1626956239856,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "oLi_mtsqUhGX",
        "outputId": "6584ff53-3618-4079-c960-892f55c6a635"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "10 loops, best of 5: 394 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%%timeit -n 10\n",
        "with tf.device('cpu:0'):\n",
        "  european_option_price(\n",
        "        mean_reversion=mean_reversion,\n",
        "        theta=thetas,\n",
        "        volvol=volvol,\n",
        "        rho=rhos,\n",
        "        variances=variances,\n",
        "        spots=spots,\n",
        "        expiries=expiries,\n",
        "        strikes=strikes,\n",
        "        discount_factors=discount_factors,\n",
        "        is_call_options=is_call_options)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 14857,
          "status": "ok",
          "timestamp": 1626956964450,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "sHbL8i9uTm7t",
        "outputId": "aa78e10b-7e85-4369-eaf9-8e53d949b679"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "10 loops, best of 5: 277 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%%timeit -n 10\n",
        "with tf.device('gpu:0'):\n",
        "  european_option_price(\n",
        "        mean_reversion=mean_reversion,\n",
        "        theta=thetas,\n",
        "        volvol=volvol,\n",
        "        rho=rhos,\n",
        "        variances=variances,\n",
        "        spots=spots,\n",
        "        expiries=expiries,\n",
        "        strikes=strikes,\n",
        "        discount_factors=discount_factors,\n",
        "        is_call_options=is_call_options).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TCbmsqZbVgoJ"
      },
      "source": [
        "Much better performance can be achieved using XLA's just in time compilation. This can be enabled by passing the argument `jit_compile=True` when calling `tf.function`.."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vW3NXg0FTUOC"
      },
      "outputs": [],
      "source": [
        "european_option_price_xla = tf.function(\n",
        "    tff.models.heston.approximations.european_option_price,\n",
        "    jit_compile=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 3061,
          "status": "ok",
          "timestamp": 1626956995249,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "agAJ6jb5J0ug",
        "outputId": "b047422b-0ac6-4c61-9a5e-9ec187c776ca"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The slowest run took 14.82 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
            "20 loops, best of 5: 8.34 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%%timeit -n 20\n",
        "with tf.device('gpu:0'):\n",
        "  european_option_price_xla(\n",
        "        mean_reversion=mean_reversion,\n",
        "        theta=thetas,\n",
        "        volvol=volvol,\n",
        "        rho=rhos,\n",
        "        variances=variances,\n",
        "        spots=spots,\n",
        "        expiries=expiries,\n",
        "        strikes=strikes,\n",
        "        discount_factors=discount_factors,\n",
        "        is_call_options=is_call_options).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mcUtMqfxKKe7"
      },
      "source": [
        "This is much better! The calculation now runs much faster. Sometimes we will receive a warning that `The slowest run took x times longer than the fastest`.\n",
        "This is because on the first run the actual compilation has to be performed (thus, just-in-time compilation). For the successive runs for a new set of inputs, the compiled function is used.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 1535,
          "status": "ok",
          "timestamp": 1626957605669,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "LlW7kG-SM3ef",
        "outputId": "0cc52592-4f86-4a05-8749-4e7063df0256"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Median time to price a batch of 1k options: 8.356451988220215 ms\n"
          ]
        }
      ],
      "source": [
        "# Repeat pricing 100k options with 1000 options per batch for\n",
        "times = []\n",
        "prices = []\n",
        "for _ in range(100):\n",
        "  # Generate new inputs\n",
        "  dtype = np.float64\n",
        "  num_options = 1000\n",
        "  mean_reversion = tf.random.uniform(\n",
        "      shape=[num_options], minval=0.1, maxval=10.0, dtype=dtype) \n",
        "  thetas = tf.random.uniform(shape=[num_options], minval=0.1, maxval=.5,\n",
        "                            dtype=dtype) \n",
        "  variances = tf.random.uniform(shape=[num_options], minval=0.01, maxval=0.5,\n",
        "                            dtype=dtype) \n",
        "  discount_factors = tf.random.uniform(\n",
        "      shape=[num_options], minval=0.8, maxval=0.99, dtype=dtype) \n",
        "  expiries = tf.random.uniform(\n",
        "      shape=[num_options], minval=0.1, maxval=10.0, dtype=dtype) \n",
        "  forwards = 10.0\n",
        "  spots = forwards * discount_factors\n",
        "  volvol = tf.random.uniform(\n",
        "      shape=[num_options], minval=0.05, maxval=0.9,\n",
        "      dtype=dtype) \n",
        "\n",
        "  strikes = tf.random.uniform(\n",
        "      shape=[num_options], minval=9, maxval=11,dtype=dtype) \n",
        "\n",
        "  rhos = tf.random.uniform(\n",
        "      shape=[num_options], minval=-0.8, maxval=0.8, dtype=dtype) \n",
        "\n",
        "  # Randomly assing put/call indicators\n",
        "  is_call_options = tf.cast(tf.random.stateless_binomial(\n",
        "      [num_options], seed=[4, 2], counts=True, probs=0.5), tf.bool)\n",
        "\n",
        "  # Pricing the options for the new inputs using XLA-compiled function\n",
        "  start_time = time.time()\n",
        "  with tf.device('gpu:0'):\n",
        "    prices.append(european_option_price_xla(\n",
        "        mean_reversion=mean_reversion,\n",
        "        theta=thetas,\n",
        "        volvol=volvol,\n",
        "        rho=rhos,\n",
        "        variances=variances,\n",
        "        spots=spots,\n",
        "        expiries=expiries,\n",
        "        strikes=strikes,\n",
        "        discount_factors=discount_factors,\n",
        "        is_call_options=is_call_options).numpy())\n",
        "  end_time = time.time()\n",
        "  times.append(1000 * (end_time - start_time))\n",
        "print(f\"Median time to price a batch of 1k options: {np.median(times)} ms\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2FNZpuNjvrYj"
      },
      "source": [
        "Note that XLA compilation can be used for CPU devices as well, and can result in significant speed up compared to the non-compiled code. For the option pricing example, we opserve performance improvement of around 15%."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 34539,
          "status": "ok",
          "timestamp": 1627033408758,
          "user": {
            "displayName": "Cyril Chimisov",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s64",
            "userId": "02803093032097482871"
          },
          "user_tz": -60
        },
        "id": "0fkdGYSpvnTQ",
        "outputId": "b6a42ff4-ed69-48f5-c753-1dcecb6a7589"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "20 loops, best of 5: 329 ms per loop\n"
          ]
        }
      ],
      "source": [
        "%%timeit -n 20\n",
        "with tf.device('cpu:0'):\n",
        "  european_option_price_xla(\n",
        "        mean_reversion=mean_reversion,\n",
        "        theta=thetas,\n",
        "        volvol=volvol,\n",
        "        rho=rhos,\n",
        "        variances=variances,\n",
        "        spots=spots,\n",
        "        expiries=expiries,\n",
        "        strikes=strikes,\n",
        "        discount_factors=discount_factors,\n",
        "        is_call_options=is_call_options).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-N7-1lVuO0cm"
      },
      "source": [
        "**NB** \n",
        "* XLA comes with Ahead-of-time (AOT) mode as well. See the [official documentation](https://www.tensorflow.org/xla/tfcompile) for details.\n",
        "* Not all function can be XLA-compiled. Some TensorFlow ops are not yet supported by XLA. Also, at the moment all shapes must be known at a runtime. Therefore, at the moment XLA compilation fails for the binomial option pricing algorithm above (we change tree size at each iteration of the algorithm).\n",
        "* XLA-compiled code is not guaranteed to perform better than the non-compiled one.\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Vectorization_and_XLA_compilation.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
